{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tutorial on training a HTS-AT model for audio classification on the ESC-50 Dataset\n",
    "\n",
    "Referece: \n",
    "\n",
    "[HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection, ICASSP 2022](https://arxiv.org/abs/2202.00874)\n",
    "\n",
    "Following the HTS-AT's paper, in this tutorial, we would show how to use the HST-AT in the training of the ESC-50 Dataset.\n",
    "\n",
    "The [ESC-50 dataset](https://github.com/karolpiczak/ESC-50) is a labeled collection of 2000 environmental audio recordings suitable for benchmarking methods of environmental sound classification. The dataset consists of 5-second-long recordings organized into 50 semantical classes (with 40 examples per class) loosely arranged into 5 major categories\n",
    "\n",
    "Before running this tutorial, please make sure that you install the below packages by following steps:\n",
    "\n",
    "1. download [the codebase](https://github.com/RetroCirce/HTS-Audio-Transformer), and put this tutorial notebook inside the codebase folder.\n",
    "\n",
    "2. In the github code folder:\n",
    "\n",
    "    > pip install -r requirements.txt\n",
    "\n",
    "3. We do not include the installation of PyTorch in the requirment, since different machines require different vereions of CUDA and Toolkits. So make sure you install the PyTorch from [the official guidance](https://pytorch.org/).\n",
    "\n",
    "4. Install the 'SOX' and the 'ffmpeg', we recommend that you run this code in Linux inside the Conda environment. In that, you can install them by:\n",
    "\n",
    "    > sudo apt install sox\n",
    "    \n",
    "    > conda install -c conda-forge ffmpeg\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import basic packages\n",
    "import os\n",
    "import numpy as np\n",
    "import wget\n",
    "import sys\n",
    "import gdown\n",
    "import zipfile\n",
    "import librosa\n",
    "# in the notebook, we only can use one GPU\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the workspace and download the needed files\n",
    "\n",
    "def create_path(path):\n",
    "    if not os.path.exists(path):\n",
    "        os.mkdir(path)\n",
    "\n",
    "workspace = \"./workspace\"\n",
    "dataset_path = os.path.join(workspace, \"esc-50\")\n",
    "checkpoint_path = os.path.join(workspace, \"ckpt\")\n",
    "esc_raw_path = os.path.join(dataset_path, 'raw')\n",
    "\n",
    "\n",
    "create_path(workspace)\n",
    "create_path(dataset_path)\n",
    "create_path(checkpoint_path)\n",
    "create_path(esc_raw_path)\n",
    "\n",
    "\n",
    "# download the esc-50 dataset\n",
    "\n",
    "if not os.path.exists(os.path.join(dataset_path, 'ESC-50-master.zip')):\n",
    "    print(\"-------------Downloading ESC-50 Dataset-------------\")\n",
    "    wget.download('https://github.com/karoldvl/ESC-50/archive/master.zip', out=dataset_path)\n",
    "    with zipfile.ZipFile(os.path.join(dataset_path, 'ESC-50-master.zip'), 'r') as zip_ref:\n",
    "        zip_ref.extractall(esc_raw_path)\n",
    "    print(\"-------------Success-------------\")\n",
    "\n",
    "if not os.path.exists(os.path.join(checkpoint_path,'htsat_audioset_pretrain.ckpt')):\n",
    "    gdown.download(id='1OK8a5XuMVLyeVKF117L8pfxeZYdfSDZv', output=os.path.join(checkpoint_path,'htsat_audioset_pretrain.ckpt'))\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Process ESC-50 Dataset\n",
    "meta_path = os.path.join(esc_raw_path, 'ESC-50-master', 'meta', 'esc50.csv')\n",
    "audio_path = os.path.join(esc_raw_path, 'ESC-50-master', 'audio')\n",
    "resample_path = os.path.join(dataset_path, 'resample')\n",
    "savedata_path = os.path.join(dataset_path, 'esc-50-data.npy')\n",
    "create_path(resample_path)\n",
    "\n",
    "meta = np.loadtxt(meta_path , delimiter=',', dtype='str', skiprows=1)\n",
    "audio_list = os.listdir(audio_path)\n",
    "\n",
    "# resample\n",
    "print(\"-------------Resample ESC-50-------------\")\n",
    "for f in audio_list:\n",
    "    full_f = os.path.join(audio_path, f)\n",
    "    resample_f = os.path.join(resample_path, f)\n",
    "    if not os.path.exists(resample_f):\n",
    "        os.system('sox -V1 ' + full_f + ' -r 32000 ' + resample_f)\n",
    "print(\"-------------Success-------------\")\n",
    "\n",
    "print(\"-------------Build Dataset-------------\")\n",
    "output_dict = [[] for _ in range(5)]\n",
    "for label in meta:\n",
    "    name = label[0]\n",
    "    fold = label[1]\n",
    "    target = label[2]\n",
    "    y, sr = librosa.load(os.path.join(resample_path, name), sr = None)\n",
    "    output_dict[int(fold) - 1].append(\n",
    "        {\n",
    "            \"name\": name,\n",
    "            \"target\": int(target),\n",
    "            \"waveform\": y\n",
    "        }\n",
    "    )\n",
    "np.save(savedata_path, output_dict)\n",
    "print(\"-------------Success-------------\")\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model package\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data.distributed import DistributedSampler\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "import warnings\n",
    "\n",
    "from utils import create_folder, dump_config, process_idc\n",
    "import esc_config as config\n",
    "from sed_model import SEDWrapper, Ensemble_SEDWrapper\n",
    "from data_generator import ESC_Dataset\n",
    "from model.htsat import HTSAT_Swin_Transformer\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data Preparation\n",
    "class data_prep(pl.LightningDataModule):\n",
    "    def __init__(self, train_dataset, eval_dataset, device_num):\n",
    "        super().__init__()\n",
    "        self.train_dataset = train_dataset\n",
    "        self.eval_dataset = eval_dataset\n",
    "        self.device_num = device_num\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        train_sampler = DistributedSampler(self.train_dataset, shuffle = False) if self.device_num > 1 else None\n",
    "        train_loader = DataLoader(\n",
    "            dataset = self.train_dataset,\n",
    "            num_workers = config.num_workers,\n",
    "            batch_size = config.batch_size // self.device_num,\n",
    "            shuffle = False,\n",
    "            sampler = train_sampler\n",
    "        )\n",
    "        return train_loader\n",
    "    def val_dataloader(self):\n",
    "        eval_sampler = DistributedSampler(self.eval_dataset, shuffle = False) if self.device_num > 1 else None\n",
    "        eval_loader = DataLoader(\n",
    "            dataset = self.eval_dataset,\n",
    "            num_workers = config.num_workers,\n",
    "            batch_size = config.batch_size // self.device_num,\n",
    "            shuffle = False,\n",
    "            sampler = eval_sampler\n",
    "        )\n",
    "        return eval_loader\n",
    "    def test_dataloader(self):\n",
    "        test_sampler = DistributedSampler(self.eval_dataset, shuffle = False) if self.device_num > 1 else None\n",
    "        test_loader = DataLoader(\n",
    "            dataset = self.eval_dataset,\n",
    "            num_workers = config.num_workers,\n",
    "            batch_size = config.batch_size // self.device_num,\n",
    "            shuffle = False,\n",
    "            sampler = test_sampler\n",
    "        )\n",
    "        return test_loader\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the workspace\n",
    "device_num = torch.cuda.device_count()\n",
    "print(\"each batch size:\", config.batch_size // device_num)\n",
    "\n",
    "full_dataset = np.load(os.path.join(config.dataset_path, \"esc-50-data.npy\"), allow_pickle = True)\n",
    "\n",
    "# set exp folder\n",
    "exp_dir = os.path.join(config.workspace, \"results\", config.exp_name)\n",
    "checkpoint_dir = os.path.join(config.workspace, \"results\", config.exp_name, \"checkpoint\")\n",
    "if not config.debug:\n",
    "    create_folder(os.path.join(config.workspace, \"results\"))\n",
    "    create_folder(exp_dir)\n",
    "    create_folder(checkpoint_dir)\n",
    "    dump_config(config, os.path.join(exp_dir, config.exp_name), False)\n",
    "\n",
    "print(\"Using ESC\")\n",
    "dataset = ESC_Dataset(\n",
    "    dataset = full_dataset,\n",
    "    config = config,\n",
    "    eval_mode = False\n",
    ")\n",
    "eval_dataset = ESC_Dataset(\n",
    "    dataset = full_dataset,\n",
    "    config = config,\n",
    "    eval_mode = True\n",
    ")\n",
    "\n",
    "audioset_data = data_prep(dataset, eval_dataset, device_num)\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    monitor = \"acc\",\n",
    "    filename='l-{epoch:d}-{acc:.3f}',\n",
    "    save_top_k = 20,\n",
    "    mode = \"max\"\n",
    ")\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the Trainer\n",
    "trainer = pl.Trainer(\n",
    "    deterministic=False,\n",
    "    default_root_dir = checkpoint_dir,\n",
    "    gpus = device_num, \n",
    "    val_check_interval = 1.0,\n",
    "    max_epochs = config.max_epoch,\n",
    "    auto_lr_find = True,    \n",
    "    sync_batchnorm = True,\n",
    "    callbacks = [checkpoint_callback],\n",
    "    accelerator = \"ddp\" if device_num > 1 else None,\n",
    "    num_sanity_val_steps = 0,\n",
    "    resume_from_checkpoint = None, \n",
    "    replace_sampler_ddp = False,\n",
    "    gradient_clip_val=1.0\n",
    ")\n",
    "\n",
    "sed_model = HTSAT_Swin_Transformer(\n",
    "    spec_size=config.htsat_spec_size,\n",
    "    patch_size=config.htsat_patch_size,\n",
    "    in_chans=1,\n",
    "    num_classes=config.classes_num,\n",
    "    window_size=config.htsat_window_size,\n",
    "    config = config,\n",
    "    depths = config.htsat_depth,\n",
    "    embed_dim = config.htsat_dim,\n",
    "    patch_stride=config.htsat_stride,\n",
    "    num_heads=config.htsat_num_head\n",
    ")\n",
    "\n",
    "model = SEDWrapper(\n",
    "    sed_model = sed_model, \n",
    "    config = config,\n",
    "    dataset = dataset\n",
    ")\n",
    "\n",
    "if config.resume_checkpoint is not None:\n",
    "    print(\"Load Checkpoint from \", config.resume_checkpoint)\n",
    "    ckpt = torch.load(config.resume_checkpoint, map_location=\"cpu\")\n",
    "    ckpt[\"state_dict\"].pop(\"sed_model.head.weight\")\n",
    "    ckpt[\"state_dict\"].pop(\"sed_model.head.bias\")\n",
    "    # finetune on the esc and spv2 dataset\n",
    "    ckpt[\"state_dict\"].pop(\"sed_model.tscam_conv.weight\")\n",
    "    ckpt[\"state_dict\"].pop(\"sed_model.tscam_conv.bias\")\n",
    "    model.load_state_dict(ckpt[\"state_dict\"], strict=False)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training the model\n",
    "# You can set different fold index by setting 'esc_fold' to any number from 0-4 in esc_config.py\n",
    "trainer.fit(model, audioset_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Now Let us Check the Result\n",
    "\n",
    "Find the path of your saved checkpoint and paste it in the below variable.\n",
    "Then you are able to follow the below code for checking the prediction result of any sample you like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# infer the single data to check the result\n",
    "# get a model you saved\n",
    "model_path = 'paste your saved model path'\n",
    "\n",
    "# get the groundtruth\n",
    "meta = np.loadtxt(meta_path , delimiter=',', dtype='str', skiprows=1)\n",
    "gd = {}\n",
    "for label in meta:\n",
    "    name = label[0]\n",
    "    target = label[2]\n",
    "    gd[name] = target\n",
    "\n",
    "class Audio_Classification:\n",
    "    def __init__(self, model_path, config):\n",
    "        super().__init__()\n",
    "\n",
    "        self.device = torch.device('cuda')\n",
    "        self.sed_model = HTSAT_Swin_Transformer(\n",
    "            spec_size=config.htsat_spec_size,\n",
    "            patch_size=config.htsat_patch_size,\n",
    "            in_chans=1,\n",
    "            num_classes=config.classes_num,\n",
    "            window_size=config.htsat_window_size,\n",
    "            config = config,\n",
    "            depths = config.htsat_depth,\n",
    "            embed_dim = config.htsat_dim,\n",
    "            patch_stride=config.htsat_stride,\n",
    "            num_heads=config.htsat_num_head\n",
    "        )\n",
    "        ckpt = torch.load(model_path, map_location=\"cpu\")\n",
    "        temp_ckpt = {}\n",
    "        for key in ckpt[\"state_dict\"]:\n",
    "            temp_ckpt[key[10:]] = ckpt['state_dict'][key]\n",
    "        self.sed_model.load_state_dict(temp_ckpt)\n",
    "        self.sed_model.to(self.device)\n",
    "        self.sed_model.eval()\n",
    "\n",
    "\n",
    "    def predict(self, audiofile):\n",
    "\n",
    "        if audiofile:\n",
    "            waveform, sr = librosa.load(audiofile, sr=32000)\n",
    "\n",
    "            with torch.no_grad():\n",
    "                x = torch.from_numpy(waveform).float().to(self.device)\n",
    "                output_dict = self.sed_model(x[None, :], None, True)\n",
    "                pred = output_dict['clipwise_output']\n",
    "                pred_post = pred[0].detach().cpu().numpy()\n",
    "                pred_label = np.argmax(pred_post)\n",
    "                pred_prob = np.max(pred_post)\n",
    "            return pred_label, pred_prob\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference\n",
    "Audiocls = Audio_Classification(model_path, config)\n",
    "\n",
    "# pick any audio you like in the ESC-50 testing set (cross-validation)\n",
    "pred_label, pred_prob = Audiocls.predict(\"./workspace/esc-50/raw/ESC-50-master/audio/1-7456-A-13.wav\")\n",
    "\n",
    "print('Audiocls predict output: ', pred_label, pred_prob, gd[\"1-7456-A-13.wav\"])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('kechen_py38')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "cb1a0df39641c41734bdd2d42699ec57167c4cf18fd061cdef52c16cce6262af"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
